package com.s54;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.s54.common.beans.AuthUser;
import com.s54.common.beans.R;
import com.s54.common.enums.SecurityEnum;
import com.s54.common.service.TokenService;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpHeaders;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;

@Slf4j
@Component
@RequiredArgsConstructor
public class AuthFilter implements GlobalFilter {
    private final TokenService tokenService;
    private final IgnoredUrlsProperties ignoredUrlsProperties;
    private AntPathMatcher antPathMatcher = new AntPathMatcher();
    @Value("${demo:false}")
    private boolean demo;
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        String path = request.getURI().getPath();
        String method = request.getMethodValue();
        String token = request.getHeaders().getFirst(SecurityEnum.HEADER_TOKEN.getValue());
        String userAgent = request.getHeaders().getFirst(HttpHeaders.USER_AGENT);
        if (log.isDebugEnabled()) {
            log.debug("{},{},{}", method, path, token);
        }
        if (isIgnored(path)) {
            return chain.filter(exchange);
        }
        if (!StringUtils.hasText(token)) {
            return error(exchange.getResponse(),HttpStatus.UNAUTHORIZED, "请先登录");
        }
        if (!tokenService.isValid(token, userAgent)) {
            return error(exchange.getResponse(), HttpStatus.UNAUTHORIZED, "无效令牌/用户在其他设备登录");
        }
        if (hasPerms(method, path, token, userAgent)) {
            if (demo && !"GET".equals(method)) {
                return error(exchange.getResponse(), HttpStatus.UNAUTHORIZED, "演示模式不能修改数据");
            }
            return chain.filter(exchange);
        }
        return error(exchange.getResponse(),HttpStatus.FORBIDDEN, "禁止访问");
    }

    @SneakyThrows
    private static Mono<Void> error(ServerHttpResponse response, HttpStatus httpStatus, String errorMessage) {
        response.setStatusCode(httpStatus);
        response.getHeaders().add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8");
        R err = R.err(httpStatus.value(), errorMessage);
        ObjectMapper objectMapper = new ObjectMapper();
        String json = objectMapper.writeValueAsString(err);
        DataBuffer wrap = response.bufferFactory().wrap(json.getBytes(StandardCharsets.UTF_8));
        return response.writeWith(Flux.just(wrap));
    }

    private boolean isIgnored(String path) {
        for (String pattern : ignoredUrlsProperties.getUrls()) {
            if (antPathMatcher.match(pattern, path)) {
                return true;
            }
        }
        return false;
    }

    private boolean hasPerms(String method, String path, String token, String userAgent) {
        AuthUser authUser = tokenService.getUser(token, userAgent);
        if (authUser == null) return false;
        String actualApiPath = method + "_" + path;
        for (String pattern : authUser.getApiPaths()) {
            if (antPathMatcher.match(pattern, actualApiPath)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println("String.join(\".\", \"a\", \"b\", \"c\") = " + String.join(".", "a", "b", "c"));
        AntPathMatcher matcher = new AntPathMatcher();
        String pattern = "/user/**";
        System.out.println("matcher.match(pattern, \"/user/login\") = " + matcher.match(pattern, "/user/login"));
        System.out.println("matcher.match(pattern, \"/user/logout\") = " + matcher.match(pattern, "/user/logout"));
        System.out.println("matcher.match(pattern, \"/user/10\") = " + matcher.match(pattern, "/user/10"));
        pattern = "/system/users/{id}";
        System.out.println("matcher.match(pattern, \"/system/users/10\") = " + matcher.match(pattern, "/system/users/10"));
    }
}
